def graph_feature_set_scatter(
    all_metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    input_sizes: Dict[str, int],
    logdir: Optional[Path] = None,
    metric_to_plot: str = "Accuracy",
    name: Optional[str] = None,
    metric_range: Optional[Tuple[float, float]] = None,
    assay_task_key: str = ASSAY,
    sex_task_key: str = SEX,
    cell_type_task_key: str = CELL_TYPE,
    verbose: bool = True,
) -> None:
    """
    Graphs performance metrics as a scatter plot with modifications.
    X-axis: Number of Features (log scale).
    Y-axis: Average performance metric (e.g., Accuracy, F1_macro) across folds.
    Vertical lines indicate the min/max range across folds.
    Color: Bin Size (bp, log scale).
    Args:
        all_metrics: Nested dict {feature_set: {task_name: {split_name: metric_dict}}}.
        input_sizes: Dict {feature_set: num_features}.
        logdir: Directory to save figures. If None, display only.
        metric_to_plot: The metric key to use for the Y-axis ('Accuracy', 'F1_macro').
        name: Optional suffix for figure titles and filenames.
        metric_range: Optional tuple (min, max) to set the Y-axis range.
        assay_task_key: Key used for the assay prediction task.
        sex_task_key: Key used for the sex prediction task.
        cell_type_task_key: Key used for the cell type prediction task.
        verbose: If True, print the list of tasks being plotted.
    Raises:
        ValueError: If metric_to_plot is unsupported, a feature set is missing
            from input_sizes, a task cannot be found for a feature set, metric
            aggregation fails, no points remain for a task, or a bin size does
            not match a known category.
    """
    if metric_to_plot not in ["Accuracy", "F1_macro"]:
        raise ValueError("metric_to_plot must be 'Accuracy' or 'F1_macro'")
    # --- Standard Name Handling (simplified from original) ---
    # Maps a standard task key to the non-standard variant some feature sets use.
    non_standard_names = {ASSAY: f"{ASSAY}_11c", SEX: f"{SEX}_w-mixed"}
    # --- Find reference and task names ----
    reference_hdf5_type = next(iter(all_metrics), None)
    if reference_hdf5_type is None or not all_metrics.get(reference_hdf5_type):
        print(
            "Warning: Could not determine tasks from all_metrics. Trying default tasks."
        )
        cleaned_metadata_categories = {assay_task_key, sex_task_key, cell_type_task_key}
    else:
        # Normalize any non-standard task names back to their standard keys.
        metadata_categories = list(all_metrics[reference_hdf5_type].keys())
        cleaned_metadata_categories = set()
        for cat in metadata_categories:
            original_name = cat
            for standard, non_standard in non_standard_names.items():
                if cat == non_standard:
                    original_name = standard
                    break
            cleaned_metadata_categories.add(original_name)
    # --- Define Bin size categories and Colors ---
    bin_category_names = ["1Kb", "10Kb", "100Kb", "1Mb", "10Mb"]
    bin_category_values = [1000, 10000, 100 * 1000, 1000 * 1000, 10000 * 1000]
    discrete_colors = px.colors.sequential.Viridis_r
    # Sample every other Viridis color so the 5 categories stay distinguishable.
    color_map = {
        name: discrete_colors[i * 2] for i, name in enumerate(bin_category_names)
    }
    if verbose:
        print(f"Plotting for tasks: {list(cleaned_metadata_categories)}")
    # Sort for a deterministic plot/file order (set iteration order is arbitrary).
    for category_name in sorted(cleaned_metadata_categories):
        plot_data_points = []
        for feature_set_name_orig in all_metrics.keys():
            try:
                num_features = input_sizes[feature_set_name_orig]
            except KeyError as e:
                raise ValueError(
                    f"Feature set '{feature_set_name_orig}' not found in input_sizes"
                ) from e
            # Parse Bin Size
            bin_size = parse_bin_size(feature_set_name_orig)
            if bin_size is None:
                print(
                    f"Skipping {feature_set_name_orig}, could not parse numeric bin size."
                )
                continue
            # 3. Get Metric Values (Average, Min, Max)
            tasks_dicts = all_metrics[feature_set_name_orig]
            # --- Task Name Lookup ---
            # 1. Try the standard category name first
            # 2. If standard name not found, use non-standard name
            task_dict = None
            task_name = category_name
            if category_name in tasks_dicts:
                task_dict = tasks_dicts[category_name]
            else:
                non_standard_task_name = non_standard_names.get(category_name)
                if non_standard_task_name and non_standard_task_name in tasks_dicts:
                    task_name = non_standard_task_name
                    task_dict = tasks_dicts[non_standard_task_name]
            if task_dict is None:
                raise ValueError(
                    f"Task '{category_name}' not found in feature set '{feature_set_name_orig}'"
                )
            # --- End Task Name Lookup ---
            # Calculate average, min, max metric value across splits
            try:
                metric_values = []
                for split, split_data in task_dict.items():
                    if metric_to_plot in split_data:
                        metric_values.append(split_data[metric_to_plot])
                    else:
                        print(
                            f"Warning: Metric '{metric_to_plot}' not found in split '{split}' for {feature_set_name_orig} / {task_name}"
                        )
                if not metric_values:
                    print(
                        f"Warning: No metric values found for {feature_set_name_orig} / {task_name} / {metric_to_plot}"
                    )
                    continue
                avg_metric = np.mean(metric_values)
                min_metric = np.min(metric_values)
                max_metric = np.max(metric_values)
            except Exception as e:  # pylint: disable=broad-except
                raise ValueError(
                    f"Error calculating metrics for {feature_set_name_orig} / {task_name}: {e}"
                ) from e
            # Clean feature set name for hover text
            clean_name = feature_set_name_orig.replace("_none", "").replace("hg38_", "")
            clean_name = re.sub(r"\_[\dmkb]+\_coord", "", clean_name)
            # Store data for this point
            plot_data_points.append(
                {
                    "bin_size": bin_size,
                    "num_features": num_features,
                    "metric_value": avg_metric,
                    "min_metric": min_metric,  # For error bar low
                    "max_metric": max_metric,  # For error bar high
                    "name": clean_name,
                    "raw_name": feature_set_name_orig,
                }
            )
        if not plot_data_points:
            raise ValueError(
                f"No suitable data points found to plot for task: {category_name}"
            )
        # --- Determine Marker Symbols ---
        # Randomized feature sets get a distinct marker so they stand out.
        marker_symbols = []
        default_symbol = "circle"
        random_symbol = "cross"
        for p in plot_data_points:
            if "random" in p["raw_name"]:
                marker_symbols.append(random_symbol)
            else:
                marker_symbols.append(default_symbol)
        # --- Group Data by Category ---
        points_by_category = {name: [] for name in bin_category_names}
        for i, point_data in enumerate(plot_data_points):
            bin_size = point_data["bin_size"]
            assigned_category = None
            for cat_name, cat_value in zip(bin_category_names, bin_category_values):
                if bin_size == cat_value:
                    assigned_category = cat_name
                    break
            else:
                raise ValueError(f"Could not find category for bin size: {bin_size}")
            points_by_category[assigned_category].append(
                {
                    "x": point_data["num_features"],  # X is Num Features
                    "y": point_data["metric_value"],
                    "error_up": point_data["max_metric"] - point_data["metric_value"],
                    "error_down": point_data["metric_value"] - point_data["min_metric"],
                    "text": point_data["name"],
                    # customdata layout: [0]=min_metric, [1]=max_metric, [2]=bin_size.
                    # The hovertemplate below indexes into this list.
                    "customdata": [
                        point_data["min_metric"],
                        point_data["max_metric"],
                        point_data["bin_size"],
                    ],
                    "symbol": marker_symbols[i],  # Assign symbol determined earlier
                }
            )
        # --- Create Figure and Add Traces PER CATEGORY ---
        fig = go.Figure()
        traces = []
        for cat_name in bin_category_names:  # Iterate in defined order for legend
            points_in_cat = points_by_category[cat_name]
            if not points_in_cat:
                continue
            category_color = color_map[cat_name]
            # Extract data for all points in this category
            x_vals = [p["x"] for p in points_in_cat]
            y_vals = [p["y"] for p in points_in_cat]
            error_up_vals = [p["error_up"] for p in points_in_cat]
            error_down_vals = [p["error_down"] for p in points_in_cat]
            text_vals = [p["text"] for p in points_in_cat]
            customdata_vals = [p["customdata"] for p in points_in_cat]
            symbols_vals = [p["symbol"] for p in points_in_cat]
            trace = go.Scatter(
                x=x_vals,
                y=y_vals,
                mode="markers",
                name=cat_name,
                showlegend=False,
                legendgroup=cat_name,  # Group legend entries
                marker=dict(
                    color=category_color,
                    size=15,
                    symbol=symbols_vals,
                    line=dict(width=1, color="DarkSlateGrey"),
                ),
                error_y=dict(
                    type="data",
                    symmetric=False,
                    array=error_up_vals,
                    arrayminus=error_down_vals,
                    visible=True,
                    thickness=1.5,
                    width=15,
                    color=category_color,
                ),
                text=text_vals,
                customdata=customdata_vals,
                # BUGFIX: plotly requires indexed access (customdata[i]) when
                # customdata holds per-point lists; the previous un-indexed
                # %{customdata} rendered the whole array instead of bin size,
                # and showed the same value twice for the fold range.
                hovertemplate=(
                    f"<b>%{{text}}</b><br><br>"
                    f"Num Features: %{{x:,.0f}}<br>"
                    f"{metric_to_plot}: %{{y:.4f}}<br>"
                    f"Bin Size: %{{customdata[2]:,.0f}} bp<br>"
                    f"{metric_to_plot} Range (10-fold): %{{customdata[0]:.4f}} - %{{customdata[1]:.4f}}"
                    "<extra></extra>"
                ),
            )
            traces.append(trace)
        fig.add_traces(traces)
        # --- Add Legend ---
        # Add a hidden scatter trace with square markers for legend
        for cat_name in bin_category_names:
            category_color = color_map[cat_name]
            legend_trace = go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                name=cat_name,
                marker=dict(
                    color=category_color,
                    size=15,
                    symbol="square",
                    line=dict(width=1, color="DarkSlateGrey"),
                ),
                legendgroup=cat_name,
                showlegend=True,
            )
            fig.add_trace(legend_trace)
        # --- Update layout ---
        title_name = category_name.replace(CELL_TYPE, "biospecimen")
        plot_title = f"{metric_to_plot} vs Number of Features - {title_name}"
        if name:
            plot_title += f" - {name}"
        xaxis_title = "Number of Features (log scale)"
        xaxis_type = "log"
        yaxis_title = metric_to_plot.replace("_", " ").title()
        yaxis_type = "linear"
        fig.update_layout(
            xaxis_title=xaxis_title,
            yaxis_title=yaxis_title,
            xaxis_type=xaxis_type,
            yaxis_type=yaxis_type,
            yaxis_range=metric_range,
            width=500,
            height=500,
            hovermode="closest",
            legend_title_text="Bin Size",
            title_text=plot_title,
            **main_title_settings,
        )
        # Task-specific zoom ranges override metric_range for these two tasks.
        if category_name == CELL_TYPE:
            fig.update_yaxes(range=[0.75, 1.005])
        elif category_name == ASSAY:
            fig.update_yaxes(range=[0.96, 1.001])
        # --- Save or show figure ---
        if logdir:
            logdir.mkdir(parents=True, exist_ok=True)
            # Include "modified" or similar in filename to distinguish
            base_name = f"feature_scatter_MODIFIED_v2_{category_name}_{metric_to_plot}"
            if name:
                base_name += f"_{name}"
            html_path = logdir / f"{base_name}.html"
            svg_path = logdir / f"{base_name}.svg"
            png_path = logdir / f"{base_name}.png"
            print(f"Saving modified plot for {category_name} to {html_path}")
            fig.write_html(html_path)
            fig.write_image(svg_path)
            fig.write_image(png_path)
        fig.show()